LoRA and PEFT #
Parameter-Efficient Fine-Tuning (PEFT) lets you adapt large models by training a small number of additional parameters while keeping the base model frozen. The most widely used PEFT method is LoRA (Low-Rank Adaptation).
Understanding LoRA #
- LoRA is the most widely adopted PEFT method
- LoRA works by adding pairs of rank decomposition matrices to transformer layers
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique that freezes the pre-trained model weights and injects trainable rank decomposition matrices into the model’s layers. Instead of training all model parameters during fine-tuning, LoRA represents each weight update as the product of two much smaller low-rank matrices, significantly reducing the number of trainable parameters while maintaining model performance. For example, when applied to GPT-3 175B, LoRA reduced trainable parameters by 10,000x and GPU memory requirements by 3x compared to full fine-tuning.
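The low-rank update described above can be sketched with NumPy for a single linear layer. This is a minimal illustration rather than the PEFT library's implementation: the layer size, the rank `r`, the scaling value `alpha`, and the helper `lora_forward` are all hypothetical choices made for the example. It follows the common convention of initializing one matrix to zero so that the adapted layer starts out identical to the frozen one.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical layer size, rank, and scaling factor (illustrative values only)
d_out, d_in, r = 64, 64, 4
alpha = 8

W = rng.normal(size=(d_out, d_in))            # frozen pre-trained weight
A = rng.normal(scale=0.01, size=(r, d_in))    # trainable, small random init
B = np.zeros((d_out, r))                      # trainable, zero init: update starts at 0


def lora_forward(x):
    """Frozen projection plus the scaled low-rank update (B @ A)."""
    return x @ W.T + (alpha / r) * (x @ A.T @ B.T)


x = rng.normal(size=(2, d_in))
y = lora_forward(x)

# Compare how many parameters each approach would train for this layer
full_params = W.size
lora_params = A.size + B.size
print(f"full fine-tuning: {full_params} params, LoRA: {lora_params} params")
```

For this toy layer the adapter trains 512 parameters instead of 4,096. The saving grows with layer width, since the adapter size scales with `r * (d_in + d_out)` rather than `d_in * d_out`.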
Loading LoRA adapters #
from transformers import AutoModelForCausalLM
from peft import PeftModel
base_model = AutoModelForCausalLM.from_pretrained("<base_model_name>")
peft_model_id = "<peft_adapter_id>"
model = PeftModel.from_pretrained(base_model, peft_model_id)
Merging LoRA Adapters #
- Merging produces a single model with the combined weights, eliminating the need to load adapters separately during inference.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# 1. Load the base model
base_model = AutoModelForCausalLM.from_pretrained(
    "base_model_name",
    dtype=torch.bfloat16,
    device_map="auto",
)

# 2. Load the PEFT model with adapter
peft_model = PeftModel.from_pretrained(
    base_model,
    "path/to/adapter",
    dtype=torch.bfloat16,
)

# 3. Merge adapter weights with base model
try:
    merged_model = peft_model.merge_and_unload()
except RuntimeError as e:
    print(f"Merging failed: {e}")
    # Implement fallback strategy or memory optimization here;
    # re-raise so the save step never runs without a merged model
    raise

# 4. Save the merged model and its tokenizer
# (assumption: the tokenizer is the base model's; adjust if the adapter ships its own)
tokenizer = AutoTokenizer.from_pretrained("base_model_name")
merged_model.save_pretrained("path/to/save/merged_model")
tokenizer.save_pretrained("path/to/save/merged_model")
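Why merging leaves the model's outputs unchanged can be checked numerically on a single weight matrix. The sketch below is a NumPy illustration under assumed values (the sizes, rank, and `alpha` are hypothetical), not a description of what `merge_and_unload()` does internally. It folds the scaled low-rank product into the base weight and compares the result with the unmerged adapter path.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, rank, and scaling factor (illustrative values only)
d_out, d_in, r, alpha = 16, 32, 4, 8
scale = alpha / r

W = rng.normal(size=(d_out, d_in))   # base weight
A = rng.normal(size=(r, d_in))       # trained adapter matrix
B = rng.normal(size=(d_out, r))      # trained adapter matrix

x = rng.normal(size=(3, d_in))

# Unmerged: base projection plus a separate adapter path
y_adapter = x @ W.T + scale * (x @ A.T @ B.T)

# Merged: fold the low-rank update into one weight matrix
W_merged = W + scale * (B @ A)
y_merged = x @ W_merged.T

outputs_match = np.allclose(y_adapter, y_merged)
print("merged and unmerged outputs match:", outputs_match)
```

Because the merged matrix has the same shape as the original weight, inference on the merged model needs no extra matrix multiplications and no adapter files.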